Spark MLlib (Decision Tree, Random Forest, Gradient-Boosted Trees) -- Regression

Dataset: Kaggle Bike Sharing Demand

https://www.kaggle.com/c/bike-sharing-demand


In [1]:
sc




Out[1]:
org.apache.spark.SparkContext@ba0018e

In [2]:
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.regression._





In [7]:
val rawDatain = sc.textFile("/Users/wy/Desktop/advance_Spark/train.csv")




Out[7]:
/Users/wy/Desktop/advance_Spark/train.csv MapPartitionsRDD[7] at textFile at <console>:21

Filter out the CSV header.


In [8]:
val header = rawDatain.first




Out[8]:
datetime,season,holiday,workingday,weather,temp,atemp,humidity,windspeed,casual,registered,count

In [9]:
val rawData = rawDatain.filter(_ != header)




Out[9]:
MapPartitionsRDD[8] at filter at <console>:25

In [10]:
rawData.take(2)




Out[10]:
Array(2011-01-01 00:00:00,1,0,0,1,9.84,14.395,81,0,3,13,16, 2011-01-01 01:00:00,1,0,0,1,9.02,13.635,80,0,8,32,40)

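Drop the datetime column, take the last column (count, shifted by -1 in the code below) as the label and the remaining columns as features, and wrap each row in an MLlib LabeledPoint.

LabeledPoint data type: https://spark.apache.org/docs/latest/mllib-data-types.html#labeled-point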


In [11]:
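val data = rawData.map { line =>
	// Drop column 0 (datetime); keep columns 1-11: season, holiday, workingday, weather,
	// temp, atemp, humidity, windspeed, casual, registered, count
	val values = line.split(",").drop(1).map(_.toDouble)
	// Features: every kept column except the last (count).
	// Caution: casual + registered == count, so these two features leak the label.
	val featureVector = Vectors.dense(values.init)
	// Label: count - 1. The shift serves no purpose for regression;
	// values.last alone would predict the raw count.
	val label = values.last - 1
	LabeledPoint(label, featureVector)
}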




Out[11]:
MapPartitionsRDD[9] at map at <console>:27

In [12]:
data.take(2)




Out[12]:
Array((15.0,[1.0,0.0,0.0,1.0,9.84,14.395,81.0,0.0,3.0,13.0]), (39.0,[1.0,0.0,0.0,1.0,9.02,13.635,80.0,0.0,8.0,32.0]))
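The two sample rows make the leakage visible: 3 + 13 = 16 and 8 + 32 = 40, so casual + registered equals count, which is why the trees below split almost exclusively on features 8 and 9. A minimal sketch of a leak-free parser, assuming the same train.csv column order as the header above; BikeRow and parse are hypothetical names, and the result would still be wrapped in LabeledPoint(label, Vectors.dense(features)).

```scala
// Hypothetical helper: parse one train.csv row into (label, features) without
// the target leak. Assumes the column order of the header:
// datetime,season,holiday,workingday,weather,temp,atemp,humidity,windspeed,casual,registered,count
object BikeRow {
  def parse(line: String): (Double, Array[Double]) = {
    val cols = line.split(",")
    // Columns 1..8: season, holiday, workingday, weather, temp, atemp, humidity, windspeed.
    // datetime (0), casual (9) and registered (10) are left out.
    val features = cols.slice(1, 9).map(_.toDouble)
    // count (11) is used as-is, without the -1 shift.
    val label = cols(11).toDouble
    (label, features)
  }
}

// In the notebook this could replace the body of the map above, e.g.
// val data = rawData.map { line =>
//   val (label, features) = BikeRow.parse(line)
//   LabeledPoint(label, Vectors.dense(features))
// }
```

With the leak removed, the test MSE values reported below are no longer directly comparable, because the model then has to predict demand from weather and calendar features alone.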

In [13]:
val Array(trainData, testData) = data.randomSplit(Array(0.8, 0.2))




Out[13]:
PartitionwiseSampledRDD[11] at randomSplit at <console>:29

In [14]:
trainData.cache()
testData.cache()




Out[14]:
PartitionwiseSampledRDD[11] at randomSplit at <console>:29

In [15]:
import org.apache.spark.mllib.evaluation._
import org.apache.spark.mllib.tree._
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd._





DecisionTree (Regression)


In [16]:
val categoricalFeaturesInfo = Map[Int, Int]()
val impurity = "variance"
val maxDepth = 5
val maxBins = 32




Out[16]:
32
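Here categoricalFeaturesInfo is empty, so season, holiday, workingday and weather are all treated as continuous. MLlib reads an entry n -> k as "feature n is categorical with k values, indexed 0 to k-1", so the 1-based season and weather codes would first need shifting. A minimal sketch, assuming the feature layout built above (indices 0-3 are season, holiday, workingday, weather) and the dataset's 1-4 coding of season and weather; Categorical and zeroBase are hypothetical names.

```scala
// Hypothetical helper for declaring categorical features to MLlib trees.
// Assumed feature layout (as built above): 0 = season (1-4), 1 = holiday (0/1),
// 2 = workingday (0/1), 3 = weather (1-4); the remaining features stay continuous.
object Categorical {
  // feature index -> number of categories
  val info: Map[Int, Int] = Map(0 -> 4, 1 -> 2, 2 -> 2, 3 -> 4)

  // MLlib expects category values 0 until k, so shift the 1-based codes down by one.
  // Returns a copy; the input array is left unchanged.
  def zeroBase(features: Array[Double]): Array[Double] = {
    val out = features.clone()
    out(0) = out(0) - 1 // season: 1-4 -> 0-3
    out(3) = out(3) - 1 // weather: 1-4 -> 0-3
    out
  }
}
```

The training call would then pass Categorical.info instead of the empty map, with feature vectors built from Categorical.zeroBase(...). maxBins must be at least the largest arity (4 here), which the current value of 32 already satisfies.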

In [17]:
val model = DecisionTree.trainRegressor(trainData, categoricalFeaturesInfo, impurity, maxDepth, maxBins)




Out[17]:
DecisionTreeModel regressor of depth 5 with 63 nodes

In [19]:
val labelsAndPredictions = testData.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction)
  // (point.features,point.label, prediction)
}




Out[19]:
MapPartitionsRDD[31] at map at <console>:45

In [20]:
val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean()
println("Test Mean Squared Error = " + testMSE)
println("Learned regression tree model:\n" + model.toDebugString)


Test Mean Squared Error = 705.3931281033362
Learned regression tree model:
DecisionTreeModel regressor of depth 5 with 63 nodes
  If (feature 9 <= 207.0)
   If (feature 9 <= 87.0)
    If (feature 9 <= 35.0)
     If (feature 9 <= 14.0)
      If (feature 9 <= 6.0)
       Predict: 3.627329192546584
      Else (feature 9 > 6.0)
       Predict: 11.460144927536232
     Else (feature 9 > 14.0)
      If (feature 9 <= 27.0)
       Predict: 23.94160583941606
      Else (feature 9 > 27.0)
       Predict: 36.08270676691729
    Else (feature 9 > 35.0)
     If (feature 9 <= 66.0)
      If (feature 9 <= 46.0)
       Predict: 46.820422535211264
      Else (feature 9 > 46.0)
       Predict: 66.06410256410257
     Else (feature 9 > 66.0)
      If (feature 8 <= 25.0)
       Predict: 85.62679425837321
      Else (feature 8 > 25.0)
       Predict: 116.21428571428571
   Else (feature 9 > 87.0)
    If (feature 9 <= 141.0)
     If (feature 8 <= 37.0)
      If (feature 9 <= 108.0)
       Predict: 111.73198198198199
      Else (feature 9 > 108.0)
       Predict: 139.7277777777778
     Else (feature 8 > 37.0)
      If (feature 8 <= 98.0)
       Predict: 171.31791907514452
      Else (feature 8 > 98.0)
       Predict: 259.625
    Else (feature 9 > 141.0)
     If (feature 8 <= 62.0)
      If (feature 9 <= 165.0)
       Predict: 179.89086859688197
      Else (feature 9 > 165.0)
       Predict: 214.00183486238532
     Else (feature 8 > 62.0)
      If (feature 8 <= 124.0)
       Predict: 259.0870967741935
      Else (feature 8 > 124.0)
       Predict: 345.54285714285714
  Else (feature 9 > 207.0)
   If (feature 9 <= 433.0)
    If (feature 8 <= 177.0)
     If (feature 9 <= 289.0)
      If (feature 8 <= 82.0)
       Predict: 281.3866481223922
      Else (feature 8 > 82.0)
       Predict: 365.67905405405406
     Else (feature 9 > 289.0)
      If (feature 8 <= 62.0)
       Predict: 378.82254196642685
      Else (feature 8 > 62.0)
       Predict: 451.8471074380165
    Else (feature 8 > 177.0)
     If (feature 9 <= 289.0)
      If (feature 9 <= 262.0)
       Predict: 452.47169811320754
      Else (feature 9 > 262.0)
       Predict: 499.81481481481484
     Else (feature 9 > 289.0)
      If (feature 9 <= 365.0)
       Predict: 567.4622641509434
      Else (feature 9 > 365.0)
       Predict: 650.7555555555556
   Else (feature 9 > 433.0)
    If (feature 9 <= 542.0)
     If (feature 8 <= 48.0)
      If (feature 8 <= 22.0)
       Predict: 492.4848484848485
      Else (feature 8 > 22.0)
       Predict: 510.8142857142857
     Else (feature 8 > 48.0)
      If (feature 8 <= 177.0)
       Predict: 562.1825396825396
      Else (feature 8 > 177.0)
       Predict: 692.7272727272727
    Else (feature 9 > 542.0)
     If (feature 8 <= 62.0)
      If (feature 6 <= 57.0)
       Predict: 636.5333333333333
      Else (feature 6 > 57.0)
       Predict: 680.8555555555556
     Else (feature 8 > 62.0)
      If (feature 8 <= 98.0)
       Predict: 748.8987341772151
      Else (feature 8 > 98.0)
       Predict: 802.3898305084746
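MSE is reported here, but the Kaggle competition scores submissions by root mean squared logarithmic error (RMSLE), which penalizes relative rather than absolute errors. A minimal sketch of both metrics on plain (label, prediction) pairs; the Metrics object is hypothetical, and clamping negative predictions to 0 before the logarithm is an assumption (log(1 + p) is undefined for p <= -1).

```scala
// Hypothetical helpers computing regression metrics over (label, prediction) pairs,
// matching the tuple order used by labelsAndPredictions.
object Metrics {
  // Mean squared error: average of (y - p)^2.
  def mse(pairs: Seq[(Double, Double)]): Double =
    pairs.map { case (y, p) => (y - p) * (y - p) }.sum / pairs.size

  // Root mean squared logarithmic error: sqrt(mean((log(1 + p) - log(1 + y))^2)).
  // Negative predictions are clamped to 0 so log1p stays defined.
  def rmsle(pairs: Seq[(Double, Double)]): Double = {
    val sq = pairs.map { case (y, p) =>
      val d = math.log1p(math.max(p, 0.0)) - math.log1p(y)
      d * d
    }
    math.sqrt(sq.sum / sq.size)
  }
}
```

On the RDD the same quantity is math.sqrt(labelsAndPredictions.map { case (y, p) => math.pow(math.log1p(math.max(p, 0.0)) - math.log1p(y), 2) }.mean()). MLlib also provides org.apache.spark.mllib.evaluation.RegressionMetrics (MSE, RMSE, MAE, r2), which expects (prediction, observation) pairs, so the tuples would need swapping. Since the labels here are count - 1, an RMSLE computed on them is not exactly the Kaggle score.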



In [90]:
// Compare a few tree depths and bin counts, scored by MSE on testData
for (impurity <- Array("variance")) {
    for (maxDepth <- Array(3, 5)) {
        for (maxBins <- Array(16, 32)) {
            val model = DecisionTree.trainRegressor(trainData, Map[Int, Int](), impurity, maxDepth, maxBins)
            val labelsAndPredictions = testData.map { point =>
                val prediction = model.predict(point.features)
                (point.label, prediction)
            }
            val testMSE = labelsAndPredictions.map { case (v, p) => math.pow(v - p, 2) }.mean()
            println(((maxDepth, maxBins), testMSE))
        }
    }
}


((3,16),2487.090266972155)
((3,32),2031.7394162627231)
((5,16),1059.674863424604)
((5,32),705.3931281033362)
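The loop above picks maxDepth and maxBins by their test MSE, so the winning score (705.39) is no longer an unbiased estimate of performance on new data. A common alternative is to tune on a separate validation split and use testData only once at the end. A minimal sketch of the selection step, assuming a caller-supplied scoring function; GridSearch and best are hypothetical names, and the commented Spark usage assumes a three-way randomSplit.

```scala
// Hypothetical grid-search helper: evaluates every (maxDepth, maxBins) pair with
// the supplied scoring function and returns the pair with the lowest score.
object GridSearch {
  def best(depths: Seq[Int], bins: Seq[Int])(score: (Int, Int) => Double): ((Int, Int), Double) = {
    val results = for (d <- depths; b <- bins) yield ((d, b), score(d, b))
    results.minBy(_._2)
  }
}

// Possible use in the notebook (assumes a train/validation/test split):
// val Array(trainData, cvData, testData) = data.randomSplit(Array(0.6, 0.2, 0.2))
// val ((bestDepth, bestBins), cvMSE) = GridSearch.best(Seq(3, 5, 10), Seq(16, 32)) { (d, b) =>
//   val m = DecisionTree.trainRegressor(trainData, Map[Int, Int](), "variance", d, b)
//   cvData.map(p => math.pow(p.label - m.predict(p.features), 2)).mean()
// }
```

Since the MSE keeps falling from depth 3 to depth 5, larger depths may be worth including in the grid.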


In [111]:
// Save model
// model.save(sc, "/Users/wy/Desktop/advance_Spark/model")
// Load model
// val sameModel = DecisionTreeModel.load(sc, "/Users/wy/Desktop/advance_Spark/model")
// Trailing expression so this comment-only cell still produces a result and is not treated as incomplete
val tmp = 1




Out[111]:
1

RandomForest (Regression)


In [91]:
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.mllib.util.MLUtils





In [92]:
val categoricalFeaturesInfo = Map[Int, Int]()
val numTrees = 3 // Use more in practice.
val featureSubsetStrategy = "auto" // Let the algorithm choose ("onethird" of the features per split for a regression forest).
val impurity = "variance"
val maxDepth = 4
val maxBins = 32




Out[92]:
32

In [93]:
val model = RandomForest.trainRegressor(trainData, categoricalFeaturesInfo,numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)




Out[93]:
TreeEnsembleModel regressor with 3 trees

In [94]:
val labelsAndPredictions = testData.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction)
}




Out[94]:
MapPartitionsRDD[278] at map at <console>:52

In [99]:
labelsAndPredictions.take(2)




Out[99]:
Array((39.0,89.70561804026455), (31.0,32.617074575331))

In [108]:
val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean()
println("Test Mean Squared Error = " + testMSE)
println("Learned regression forest model:\n" + model.toDebugString)



Test Mean Squared Error = 2150.9217640461657
Learned regression forest model:
TreeEnsembleModel regressor with 3 trees

  Tree 0:
    If (feature 8 <= 22.0)
     If (feature 0 <= 3.0)
      If (feature 8 <= 6.0)
       If (feature 9 <= 76.0)
        Predict: 20.4469305794607
       Else (feature 9 > 76.0)
        Predict: 137.444099378882
      Else (feature 8 > 6.0)
       If (feature 0 <= 1.0)
        Predict: 155.48896434634975
       Else (feature 0 > 1.0)
        Predict: 124.2315668202765
     Else (feature 0 > 3.0)
      If (feature 2 <= 0.0)
       If (feature 9 <= 66.0)
        Predict: 29.18421052631579
       Else (feature 9 > 66.0)
        Predict: 116.30281690140845
      Else (feature 2 > 0.0)
       If (feature 8 <= 7.0)
        Predict: 58.252252252252255
       Else (feature 8 > 7.0)
        Predict: 247.26012793176972
    Else (feature 8 > 22.0)
     If (feature 9 <= 289.0)
      If (feature 5 <= 30.305)
       If (feature 7 <= 11.0014)
        Predict: 210.3319209039548
       Else (feature 7 > 11.0014)
        Predict: 240.0751592356688
      Else (feature 5 > 30.305)
       If (feature 9 <= 165.0)
        Predict: 177.10721649484537
       Else (feature 9 > 165.0)
        Predict: 315.2970550576184
     Else (feature 9 > 289.0)
      If (feature 8 <= 71.0)
       If (feature 9 <= 542.0)
        Predict: 436.7192118226601
       Else (feature 9 > 542.0)
        Predict: 685.4036697247707
      Else (feature 8 > 71.0)
       If (feature 2 <= 0.0)
        Predict: 549.636690647482
       Else (feature 2 > 0.0)
        Predict: 627.1045296167248
  Tree 1:
    If (feature 9 <= 207.0)
     If (feature 9 <= 87.0)
      If (feature 9 <= 35.0)
       If (feature 9 <= 14.0)
        Predict: 7.064764841942945
       Else (feature 9 > 14.0)
        Predict: 27.89620253164557
      Else (feature 9 > 35.0)
       If (feature 9 <= 55.0)
        Predict: 53.377391304347825
       Else (feature 9 > 55.0)
        Predict: 84.10902255639098
     Else (feature 9 > 87.0)
      If (feature 9 <= 141.0)
       If (feature 9 <= 119.0)
        Predict: 126.27770859277709
       Else (feature 9 > 119.0)
        Predict: 161.79807692307693
      Else (feature 9 > 141.0)
       If (feature 9 <= 165.0)
        Predict: 192.43760683760684
       Else (feature 9 > 165.0)
        Predict: 239.4345165238678
    Else (feature 9 > 207.0)
     If (feature 8 <= 82.0)
      If (feature 4 <= 14.76)
       If (feature 2 <= 0.0)
        Predict: 285.9047619047619
       Else (feature 2 > 0.0)
        Predict: 355.94117647058823
      Else (feature 4 > 14.76)
       If (feature 9 <= 433.0)
        Predict: 333.0328587075575
       Else (feature 9 > 433.0)
        Predict: 599.9391891891892
     Else (feature 8 > 82.0)
      If (feature 6 <= 67.0)
       If (feature 2 <= 0.0)
        Predict: 477.2954091816367
       Else (feature 2 > 0.0)
        Predict: 565.9469696969697
      Else (feature 6 > 67.0)
       If (feature 0 <= 1.0)
        Predict: 472.0
       Else (feature 0 > 1.0)
        Predict: 416.6376811594203
  Tree 2:
    If (feature 9 <= 207.0)
     If (feature 8 <= 22.0)
      If (feature 8 <= 5.0)
       If (feature 8 <= 1.0)
        Predict: 17.04337899543379
       Else (feature 8 > 1.0)
        Predict: 49.50809061488673
      Else (feature 8 > 5.0)
       If (feature 8 <= 13.0)
        Predict: 85.73168724279836
       Else (feature 8 > 13.0)
        Predict: 117.93499308437067
     Else (feature 8 > 22.0)
      If (feature 9 <= 141.0)
       If (feature 9 <= 98.0)
        Predict: 115.70446735395188
       Else (feature 9 > 98.0)
        Predict: 168.62425447316105
      Else (feature 9 > 141.0)
       If (feature 8 <= 71.0)
        Predict: 213.9335260115607
       Else (feature 8 > 71.0)
        Predict: 282.31775700934577
    Else (feature 9 > 207.0)
     If (feature 8 <= 82.0)
      If (feature 9 <= 433.0)
       If (feature 4 <= 13.12)
        Predict: 289.0174418604651
       Else (feature 4 > 13.12)
        Predict: 330.49524714828897
      Else (feature 9 > 433.0)
       If (feature 9 <= 542.0)
        Predict: 523.684
       Else (feature 9 > 542.0)
        Predict: 687.9022988505748
     Else (feature 8 > 82.0)
      If (feature 2 <= 0.0)
       If (feature 4 <= 17.22)
        Predict: 412.11764705882354
       Else (feature 4 > 17.22)
        Predict: 470.1958333333333
      Else (feature 2 > 0.0)
       If (feature 9 <= 542.0)
        Predict: 440.4467005076142
       Else (feature 9 > 542.0)
        Predict: 783.8932038834952
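For regression, an MLlib random forest predicts the average of its trees' outputs. With only three depth-4 trees, each split choosing among a random subset of the features, the forest's test MSE (2150.92) is far above that of the single depth-5 tree (705.39); as the numTrees comment notes, more trees would be used in practice. The combination step is small enough to sketch in plain Scala; ForestAverage and forestPredict are hypothetical names.

```scala
// Hypothetical sketch of how a regression forest combines its trees:
// the ensemble prediction is the plain mean of the per-tree predictions.
object ForestAverage {
  def forestPredict(treePredictions: Seq[Double]): Double = {
    require(treePredictions.nonEmpty, "a forest needs at least one tree")
    treePredictions.sum / treePredictions.size
  }
}
```

Comparing ForestAverage.forestPredict(model.trees.map(_.predict(x))) with model.predict(x) for a test vector x is one way to check this against the trained model.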


In [112]:
// Save and load model
// model.save(sc, "myModelPath")
// val sameModel = RandomForestModel.load(sc, "myModelPath")
// Trailing expression so this comment-only cell still produces a result and is not treated as incomplete
val tmp = 1




Out[112]:
1

GradientBoostedTrees (Regression)


In [110]:
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel





In [113]:
val boostingStrategy = BoostingStrategy.defaultParams("Regression")
boostingStrategy.numIterations = 3 
boostingStrategy.treeStrategy.maxDepth = 5
boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]()




Out[113]:
Map()

In [114]:
val model = GradientBoostedTrees.train(trainData, boostingStrategy)




Out[114]:
TreeEnsembleModel regressor with 3 trees

In [115]:
val labelsAndPredictions = testData.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction)
}




Out[115]:
MapPartitionsRDD[343] at map at <console>:67

In [116]:
val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean()
println("Test Mean Squared Error = " + testMSE)
println("Learned regression GBT model:\n" + model.toDebugString)



Test Mean Squared Error = 687.4423958710161
Learned regression GBT model:
TreeEnsembleModel regressor with 3 trees

  Tree 0:
    If (feature 9 <= 207.0)
     If (feature 9 <= 87.0)
      If (feature 9 <= 35.0)
       If (feature 9 <= 14.0)
        If (feature 9 <= 6.0)
         Predict: 3.627329192546584
        Else (feature 9 > 6.0)
         Predict: 11.460144927536232
       Else (feature 9 > 14.0)
        If (feature 9 <= 27.0)
         Predict: 23.94160583941606
        Else (feature 9 > 27.0)
         Predict: 36.08270676691729
      Else (feature 9 > 35.0)
       If (feature 9 <= 66.0)
        If (feature 9 <= 46.0)
         Predict: 46.820422535211264
        Else (feature 9 > 46.0)
         Predict: 66.06410256410257
       Else (feature 9 > 66.0)
        If (feature 8 <= 25.0)
         Predict: 85.62679425837321
        Else (feature 8 > 25.0)
         Predict: 116.21428571428571
     Else (feature 9 > 87.0)
      If (feature 9 <= 141.0)
       If (feature 8 <= 37.0)
        If (feature 9 <= 108.0)
         Predict: 111.73198198198199
        Else (feature 9 > 108.0)
         Predict: 139.7277777777778
       Else (feature 8 > 37.0)
        If (feature 8 <= 98.0)
         Predict: 171.31791907514452
        Else (feature 8 > 98.0)
         Predict: 259.625
      Else (feature 9 > 141.0)
       If (feature 8 <= 62.0)
        If (feature 9 <= 165.0)
         Predict: 179.89086859688197
        Else (feature 9 > 165.0)
         Predict: 214.00183486238532
       Else (feature 8 > 62.0)
        If (feature 8 <= 124.0)
         Predict: 259.0870967741935
        Else (feature 8 > 124.0)
         Predict: 345.54285714285714
    Else (feature 9 > 207.0)
     If (feature 9 <= 433.0)
      If (feature 8 <= 177.0)
       If (feature 9 <= 289.0)
        If (feature 8 <= 82.0)
         Predict: 281.3866481223922
        Else (feature 8 > 82.0)
         Predict: 365.67905405405406
       Else (feature 9 > 289.0)
        If (feature 8 <= 62.0)
         Predict: 378.82254196642685
        Else (feature 8 > 62.0)
         Predict: 451.8471074380165
      Else (feature 8 > 177.0)
       If (feature 9 <= 289.0)
        If (feature 9 <= 262.0)
         Predict: 452.47169811320754
        Else (feature 9 > 262.0)
         Predict: 499.81481481481484
       Else (feature 9 > 289.0)
        If (feature 9 <= 365.0)
         Predict: 567.4622641509434
        Else (feature 9 > 365.0)
         Predict: 650.7555555555556
     Else (feature 9 > 433.0)
      If (feature 9 <= 542.0)
       If (feature 8 <= 48.0)
        If (feature 8 <= 22.0)
         Predict: 492.4848484848485
        Else (feature 8 > 22.0)
         Predict: 510.8142857142857
       Else (feature 8 > 48.0)
        If (feature 8 <= 177.0)
         Predict: 562.1825396825396
        Else (feature 8 > 177.0)
         Predict: 692.7272727272727
      Else (feature 9 > 542.0)
       If (feature 8 <= 62.0)
        If (feature 6 <= 57.0)
         Predict: 636.5333333333333
        Else (feature 6 > 57.0)
         Predict: 680.8555555555556
       Else (feature 8 > 62.0)
        If (feature 8 <= 98.0)
         Predict: 748.8987341772151
        Else (feature 8 > 98.0)
         Predict: 802.3898305084746
  Tree 1:
    If (feature 8 <= 22.0)
     If (feature 9 <= 141.0)
      If (feature 8 <= 10.0)
       If (feature 9 <= 46.0)
        If (feature 8 <= 4.0)
         Predict: 3.310401153741886
        Else (feature 8 > 4.0)
         Predict: -5.421068157817625
       Else (feature 9 > 46.0)
        If (feature 8 <= 5.0)
         Predict: 20.567192706268603
        Else (feature 8 > 5.0)
         Predict: 9.295510080826809
      Else (feature 8 > 10.0)
       If (feature 9 <= 87.0)
        If (feature 8 <= 16.0)
         Predict: -9.95230545737753
        Else (feature 8 > 16.0)
         Predict: -19.325663641916037
       Else (feature 9 > 87.0)
        If (feature 9 <= 130.0)
         Predict: 3.0991419991420166
        Else (feature 9 > 130.0)
         Predict: -21.744444444444454
     Else (feature 9 > 141.0)
      If (feature 9 <= 365.0)
       If (feature 9 <= 289.0)
        If (feature 9 <= 262.0)
         Predict: 46.16258052826596
        Else (feature 9 > 262.0)
         Predict: -15.158907145046076
       Else (feature 9 > 289.0)
        If (feature 9 <= 325.0)
         Predict: 119.20872029648999
        Else (feature 9 > 325.0)
         Predict: 44.46861334461837
      Else (feature 9 > 365.0)
       If (feature 9 <= 542.0)
        If (feature 9 <= 433.0)
         Predict: -53.98348749571781
        Else (feature 9 > 433.0)
         Predict: 2.7560445508270553E-14
       Else (feature 9 > 542.0)
        If (feature 0 <= 1.0)
         Predict: 208.9111111111112
        Else (feature 0 > 1.0)
         Predict: 42.73790849673208
    Else (feature 8 > 22.0)
     If (feature 9 <= 365.0)
      If (feature 9 <= 289.0)
       If (feature 9 <= 262.0)
        If (feature 9 <= 207.0)
         Predict: -13.9039846722839
        Else (feature 9 > 207.0)
         Predict: 5.774199520123213
       Else (feature 9 > 262.0)
        If (feature 8 <= 177.0)
         Predict: -79.09166861079237
        Else (feature 8 > 177.0)
         Predict: 6.736997790910579E-14
      Else (feature 9 > 289.0)
       If (feature 9 <= 325.0)
        If (feature 8 <= 124.0)
         Predict: 82.85094320667717
        Else (feature 8 > 124.0)
         Predict: 19.643398454033065
       Else (feature 9 > 325.0)
        If (feature 8 <= 124.0)
         Predict: 11.172273738034342
        Else (feature 8 > 124.0)
         Predict: -48.39306368178768
     Else (feature 9 > 365.0)
      If (feature 9 <= 433.0)
       If (feature 8 <= 177.0)
        If (feature 8 <= 124.0)
         Predict: -90.7769818343103
        Else (feature 8 > 124.0)
         Predict: -184.6391184573002
       Else (feature 8 > 177.0)
        If (feature 7 <= 11.0014)
         Predict: 50.17777777777784
        Else (feature 7 > 11.0014)
         Predict: -25.088888888888825
      Else (feature 9 > 433.0)
       If (feature 0 <= 2.0)
        If (feature 9 <= 542.0)
         Predict: 13.742144387598863
        Else (feature 9 > 542.0)
         Predict: 51.12070153970035
       Else (feature 0 > 2.0)
        If (feature 9 <= 542.0)
         Predict: -6.432493117599564
        Else (feature 9 > 542.0)
         Predict: -49.29519154873245
  Tree 2:
    If (feature 8 <= 22.0)
     If (feature 9 <= 141.0)
      If (feature 8 <= 10.0)
       If (feature 9 <= 46.0)
        If (feature 8 <= 4.0)
         Predict: -3.972481384490268
        Else (feature 8 > 4.0)
         Predict: 6.5052817893811525
       Else (feature 9 > 46.0)
        If (feature 8 <= 5.0)
         Predict: -24.680631247522324
        Else (feature 8 > 5.0)
         Predict: -11.154612096992182
      Else (feature 8 > 10.0)
       If (feature 9 <= 87.0)
        If (feature 8 <= 16.0)
         Predict: 11.94276654885304
        Else (feature 8 > 16.0)
         Predict: 23.19079637029924
       Else (feature 9 > 87.0)
        If (feature 9 <= 130.0)
         Predict: -3.7189703989704226
        Else (feature 9 > 130.0)
         Predict: 26.09333333333334
     Else (feature 9 > 141.0)
      If (feature 9 <= 365.0)
       If (feature 9 <= 289.0)
        If (feature 9 <= 262.0)
         Predict: -55.395096633919195
        Else (feature 9 > 262.0)
         Predict: 18.190688574055304
       Else (feature 9 > 289.0)
        If (feature 9 <= 325.0)
         Predict: -143.05046435578816
        Else (feature 9 > 325.0)
         Predict: -53.362336013542176
      Else (feature 9 > 365.0)
       If (feature 9 <= 542.0)
        If (feature 9 <= 433.0)
         Predict: 64.78018499486113
        Else (feature 9 > 433.0)
         Predict: -2.7560445508270553E-14
       Else (feature 9 > 542.0)
        If (feature 0 <= 1.0)
         Predict: -250.69333333333356
        Else (feature 0 > 1.0)
         Predict: -51.28549019607853
    Else (feature 8 > 22.0)
     If (feature 9 <= 365.0)
      If (feature 9 <= 289.0)
       If (feature 9 <= 262.0)
        If (feature 9 <= 207.0)
         Predict: 16.68478160674068
        Else (feature 9 > 207.0)
         Predict: -6.929039424147923
       Else (feature 9 > 262.0)
        If (feature 8 <= 177.0)
         Predict: 94.91000233295094
        Else (feature 8 > 177.0)
         Predict: 0.0
      Else (feature 9 > 289.0)
       If (feature 9 <= 325.0)
        If (feature 8 <= 124.0)
         Predict: -99.42113184801259
        Else (feature 8 > 124.0)
         Predict: -23.572078144839647
       Else (feature 9 > 325.0)
        If (feature 8 <= 124.0)
         Predict: -13.406728485641175
        Else (feature 8 > 124.0)
         Predict: 58.071676418145216
     Else (feature 9 > 365.0)
      If (feature 9 <= 433.0)
       If (feature 8 <= 177.0)
        If (feature 8 <= 124.0)
         Predict: 108.9323782011724
        Else (feature 8 > 124.0)
         Predict: 221.5669421487602
       Else (feature 8 > 177.0)
        If (feature 7 <= 11.0014)
         Predict: -60.213333333333516
        Else (feature 7 > 11.0014)
         Predict: 30.106666666666662
      Else (feature 9 > 433.0)
       If (feature 0 <= 2.0)
        If (feature 9 <= 542.0)
         Predict: -16.490573265118563
        Else (feature 9 > 542.0)
         Predict: -61.34484184764051
       Else (feature 0 > 2.0)
        If (feature 9 <= 542.0)
         Predict: 7.718991741119511
        Else (feature 9 > 542.0)
         Predict: 59.154229858478935
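In the GBT dump, Tree 0 is identical to the single regression tree trained earlier (same depth 5, same splits and leaf values), while Trees 1 and 2 predict small positive and negative corrections. This is consistent with MLlib's boosting scheme: the first tree is fit to the labels with weight 1.0, each later tree is fit to the pseudo-residuals of the current ensemble and weighted by learningRate (0.1 by default in BoostingStrategy), and the prediction is the weighted sum. A plain-Scala sketch of that sum, under the stated weight assumption; BoostedSum is a hypothetical name.

```scala
// Hypothetical sketch of the weighted sum a GBT regressor uses for prediction.
// Assumption: weight 1.0 for the first tree, learningRate for every later tree.
object BoostedSum {
  def treeWeights(numTrees: Int, learningRate: Double): Array[Double] =
    Array.tabulate(numTrees)(i => if (i == 0) 1.0 else learningRate)

  def predict(treePredictions: Seq[Double], weights: Seq[Double]): Double = {
    require(treePredictions.size == weights.size, "one weight per tree")
    treePredictions.zip(weights).map { case (p, w) => p * w }.sum
  }
}
```

For the trained model, model.treeWeights holds the actual weights and model.trees the individual trees, so BoostedSum.predict(model.trees.map(_.predict(x)), model.treeWeights) should match model.predict(x).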


In [118]:
// model.save(sc, "myModelPath")
// val sameModel = GradientBoostedTreesModel.load(sc, "myModelPath")
// Trailing expression so this comment-only cell still produces a result and is not treated as incomplete
val tmp = 1




Out[118]:
1

References

Spark MLlib decision trees
https://spark.apache.org/docs/latest/mllib-decision-tree.html

Book
Advanced Analytics with Spark

Dataset
Kaggle Bike Sharing Demand
https://www.kaggle.com/c/bike-sharing-demand

